import logging
import uuid
from time import sleep
from unittest.mock import AsyncMock, Mock  # patch

import jwt
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from whyhow_api.config import Settings, SettingsAPI  # SettingsAuth0
from whyhow_api.dependencies import get_db, get_settings
from whyhow_api.middleware import RateLimiter


# Mocking the MongoDB client and database
async def async_mock_db_success():
    """Dependency override yielding a fake database whose ``command`` succeeds.

    Awaiting ``command(...)`` on the yielded mock returns ``{"ok": 1.0}``.
    """
    ok_reply = {"ok": 1.0}
    fake_db = AsyncMock()
    fake_db.command = AsyncMock(return_value=ok_reply)
    yield fake_db


async def async_mock_db_failure():
    """Dependency override yielding a fake database whose ``command`` fails.

    Awaiting ``command(...)`` on the yielded mock returns ``{"ok": 0.0}``,
    simulating a failure response from the server.
    """
    failed_reply = {"ok": 0.0}
    fake_db = AsyncMock()
    fake_db.command = AsyncMock(return_value=failed_reply)
    yield fake_db


def test_root_no_basicauth(client):
    """The root endpoint answers 200 with a welcome message, no credentials sent."""
    resp = client.get("/")

    assert resp.status_code == 200
    assert resp.json().startswith("Welcome")


def test_root_correlation_id_autogenerated(client):
    """Without an incoming ``x-request-id`` the response carries a 32-char one."""
    resp = client.get("/")

    assert resp.status_code == 200
    assert resp.json().startswith("Welcome")

    request_id = resp.headers.get("x-request-id")
    assert request_id is not None
    assert len(request_id) == 32


def test_root_correlation_id_provided_valid(client):
    """A well-formed ``x-request-id`` supplied by the caller is echoed back."""
    supplied_id = str(uuid.uuid4())
    resp = client.get("/", headers={"x-request-id": supplied_id})

    assert resp.status_code == 200
    assert resp.json().startswith("Welcome")

    echoed_id = resp.headers.get("x-request-id")
    assert echoed_id is not None
    assert echoed_id == supplied_id


def test_root_correlation_id_provided_invalid(client):
    """A malformed ``x-request-id`` is discarded and replaced.

    Length alone is not the validity rule: a hyphenated ``uuid4`` string
    (36 characters) is accepted and echoed back by this app. An arbitrary
    20-character value is therefore used as the invalid input here.
    Presumably the middleware requires a UUID; confirm against its config.
    """
    bogus_id = 20 * "b"
    response = client.get("/", headers={"x-request-id": bogus_id})
    assert response.status_code == 200

    rb = response.json()
    assert rb.startswith("Welcome")

    assert "x-request-id" in response.headers
    replacement_id = response.headers["x-request-id"]
    assert replacement_id != bogus_id
    # The replacement should be a freshly generated id, in the same 32-char
    # form the server produces when no id is supplied at all.
    assert len(replacement_id) == 32


@pytest.fixture
def add_broken_route(client):
    """Temporarily register ``GET /extra``, whose handler raises ``ValueError``.

    The route is removed again at teardown so later tests see the app's
    original routing table.
    """

    @client.app.get("/extra")
    def extra():
        raise ValueError("Extra route")

    yield

    # Remove exactly the route registered above, matched by its endpoint.
    # Blindly dropping the last entry would delete the wrong route if
    # anything else was appended in the meantime. Slice assignment mutates
    # the router's existing list instead of rebinding it.
    client.app.router.routes[:] = [
        route
        for route in client.app.router.routes
        if getattr(route, "endpoint", None) is not extra
    ]


def test_correlation_id_exception(client, add_broken_route):
    """An unhandled handler error still yields a 500 carrying a request id."""
    resp = client.get("/extra")

    assert resp.status_code == 500

    request_id = resp.headers.get("x-request-id")
    assert request_id is not None
    assert len(request_id) == 32

    assert resp.json() == {"detail": "Internal Server Error"}


def test_database_connection_success(client, monkeypatch):
    """``/db`` reports a connection when the mocked ``command`` returns ok=1.0."""
    # setitem (instead of plain assignment) undoes the override at teardown.
    # That restores any previous override and keeps the mock database from
    # leaking into later tests that share the same app.
    monkeypatch.setitem(
        client.app.dependency_overrides, get_db, async_mock_db_success
    )

    response = client.get("/db")
    assert response.status_code == 200
    assert "Connected to database cluster" in response.text


def test_database_connection_failure(client, monkeypatch):
    """``/db`` reports a problem when the mocked ``command`` returns ok=0.0."""
    # setitem (instead of plain assignment) undoes the override at teardown.
    # That restores any previous override and keeps the failing mock from
    # leaking into later tests that share the same app.
    monkeypatch.setitem(
        client.app.dependency_overrides, get_db, async_mock_db_failure
    )

    response = client.get("/db")
    # Still HTTP 200; only the message body signals the problem.
    assert response.status_code == 200
    assert "Problem connecting to database cluster." in response.text


@pytest.mark.parametrize("log_level", ["DEBUG", "INFO", "ERROR"])
def test_lifespan(client, monkeypatch, log_level, caplog):
    """Check that the app lifespan configures logging and manages Mongo.

    Entering ``with client`` runs startup. After that, the ``whyhow_api``
    logger level must equal the overridden ``dev.log_level``, and
    ``connect_to_mongo`` must have been called exactly once. Leaving the
    block runs shutdown, which must call ``close_mongo_connection`` once.
    """
    # Reset the package logger so the level change made during startup is
    # observable. caplog restores the original level after the test.
    caplog.set_level(logging.NOTSET, logger="whyhow_api")

    # Presumably the lifespan resolves settings through dependency_overrides.
    # setitem undoes the override at teardown, so the last parametrized
    # settings do not leak into later tests sharing the app.
    monkeypatch.setitem(
        client.app.dependency_overrides,
        get_settings,
        lambda: Settings(
            dev={"log_level": log_level},
            mongodb={"username": "test", "password": "test", "host": "test"},
        ),
    )

    fake_connect_to_mongo = Mock()
    fake_close_mongo_connection = Mock()

    monkeypatch.setattr(
        "whyhow_api.main.connect_to_mongo", fake_connect_to_mongo
    )
    monkeypatch.setattr(
        "whyhow_api.main.close_mongo_connection",
        fake_close_mongo_connection,
    )

    assert logging.getLogger("whyhow_api").level == logging.NOTSET
    with client:
        assert logging.getLogger("whyhow_api").level == getattr(
            logging, log_level
        )

        fake_connect_to_mongo.assert_called_once()
        fake_close_mongo_connection.assert_not_called()

    fake_close_mongo_connection.assert_called_once()


def test_settings_endpoint(client):
    """``/settings`` returns exactly the expected top-level sections."""
    expected_sections = {
        "api",
        "aws",
        "dev",
        "generative",
        "embedding",
        "mongodb",
        "logfire",
    }

    payload = client.get("/settings").json()

    assert payload.keys() == expected_sections


# Settings shared by the RateLimiter tests below. The bucket holds a single
# token and gains one token per second, so an immediate second request is
# throttled. Requests to "/" are listed as excluded paths.
_rate_limit_api = SettingsAPI(
    limit_frequency_value=1,  # tokens added per second
    bucket_capacity=1,  # max tokens in bucket
    excluded_paths=["/"],
    # auth0=SettingsAuth0(
    #     domain="test-domain",
    #     audience="test-audience",
    #     algorithm="test-algorithm",
    # ),
)
test_settings = Settings(api=_rate_limit_api)


def create_mock_jwt():
    """Return an HS256-signed token whose payload is ``{"sub": "user123"}``.

    The token is signed with the fixed secret ``"test-secret"``.
    """
    return jwt.encode(
        {"sub": "user123"},
        "test-secret",
        algorithm="HS256",
    )


# def test_rate_limit_bearer_token(monkeypatch):
#     test_app = FastAPI()
#     test_app.add_middleware(RateLimiter)

#     @test_app.get("/test")
#     def mock_endpoint():
#         return {"message": "Hello World"}

#     test_client = TestClient(test_app)

#     monkeypatch.setattr(
#         "whyhow_api.middleware.get_settings",
#         lambda: test_settings,
#     )

#     # Create and use the mock JWT token
#     mock_token = create_mock_jwt()
#     mock_oauth2 = AsyncMock(return_value=mock_token)
#     monkeypatch.setattr(
#         "whyhow_api.middleware.OAuth2AuthorizationCodeBearer.__call__",
#         mock_oauth2,
#     )

#     mock_jwks_client = Mock()
#     mock_jwks_client.get_signing_key_from_jwt.return_value.key = "mock_key"

#     # Create a mock request state
#     mock_state = Mock()
#     mock_state.jwks_client = mock_jwks_client

#     # Patch the request to include the mocked app.state
#     with patch("fastapi.Request.app", new_callable=Mock) as mock_app:
#         mock_app.state = mock_state

#         monkeypatch.setattr(
#             "whyhow_api.middleware.jwt.decode",
#             lambda token, key, algorithms, audience, issuer: {
#                 "sub": "user123"
#             },
#         )

#         headers = {"Authorization": f"Bearer {mock_token}"}

#         # Send a request which should pass
#         response = test_client.get("/test", headers=headers)
#         assert (
#             response.status_code == 200
#         ), f"Expected 200, got {response.status_code}"
#         print(response.json())

#         # Immediately send another request which should be rate limited
#         response = test_client.get("/test", headers=headers)
#         assert (
#             response.status_code == 429
#         ), f"Expected 429, got {response.status_code}"

#         # Wait for 1 second to reset the rate limit
#         sleep(1)

#         # After a pause, another request should succeed
#         response = test_client.get("/test", headers=headers)
#         assert (
#             response.status_code == 200
#         ), f"Expected 200, got {response.status_code}"


# def test_rate_limit_bearer_token_jwks_client_exception(monkeypatch):
#     test_app = FastAPI()
#     test_app.add_middleware(RateLimiter)

#     @test_app.get("/test")
#     def mock_endpoint():
#         return {"message": "Hello World"}

#     test_client = TestClient(test_app)

#     monkeypatch.setattr(
#         "whyhow_api.middleware.get_settings",
#         lambda: test_settings,
#     )

#     # Create and use the mock JWT token
#     mock_token = create_mock_jwt()
#     mock_oauth2 = AsyncMock(return_value=mock_token)
#     monkeypatch.setattr(
#         "whyhow_api.middleware.OAuth2AuthorizationCodeBearer.__call__",
#         mock_oauth2,
#     )

#     mock_jwks_client = Mock(raise_exception=Exception("JWKS Client Exception"))
#     monkeypatch.setattr(
#         "whyhow_api.middleware.jwt.PyJWKClient", lambda url: mock_jwks_client
#     )

#     headers = {"Authorization": f"Bearer {mock_token}"}

#     with pytest.raises(Exception) as exc_info:
#         # Send a request which should pass
#         response = test_client.get("/test", headers=headers)
#         assert (
#             response.status_code == 401
#         ), f"Expected 401, got {response.status_code}"
#     assert (
#         "unable to authorize" in exc_info.value.detail.lower()
#     ), 'Expected "Unable to authorize" in exception message'
#     assert 401 == exc_info.value.status_code, "Expected 401 status code"


def test_rate_limit_api_key(monkeypatch):
    """Requests sharing an ``x-api-key`` are throttled by the one-token bucket.

    The first request succeeds. An immediate second one is rejected with
    429. After a one-second pause the bucket has refilled and requests
    succeed again.
    """
    test_app = FastAPI()
    test_app.add_middleware(RateLimiter)

    @test_app.get("/test")
    def mock_endpoint():
        return {"message": "Hello World"}

    test_client = TestClient(test_app)

    monkeypatch.setattr(
        "whyhow_api.middleware.get_settings",
        lambda: test_settings,
    )
    headers = {"x-api-key": "testkey"}

    # First request consumes the only token and should pass.
    response = test_client.get("/test", headers=headers)
    assert (
        response.status_code == 200
    ), f"Expected 200, got {response.status_code}"

    # Immediately send another request, which should be rate limited.
    response = test_client.get("/test", headers=headers)
    assert (
        response.status_code == 429
    ), f"Expected 429, got {response.status_code}"

    # limit_frequency_value=1 adds one token per second.
    sleep(1)

    # After the pause, another request should succeed.
    response = test_client.get("/test", headers=headers)
    assert (
        response.status_code == 200
    ), f"Expected 200, got {response.status_code}"


def test_rate_limit_excluded_path(monkeypatch):
    """A path listed in ``excluded_paths`` is served without any API key."""
    app_under_test = FastAPI()
    app_under_test.add_middleware(RateLimiter)

    @app_under_test.get("/")
    def public_root():
        return "public endpoint"

    client_under_test = TestClient(app_under_test)

    monkeypatch.setattr(
        "whyhow_api.middleware.get_settings",
        lambda: test_settings,
    )

    # No headers at all: the excluded root path must still answer 200.
    status = client_under_test.get("/").status_code
    assert status == 200, f"Expected 200, got {status}"


def test_rate_limit_no_user_key(monkeypatch):
    """A request to a non-excluded path with no identifying key is rejected.

    The middleware is expected to raise rather than return a response.
    ``TestClient`` re-raises server-side exceptions by default, so the
    raise surfaces at the call site.
    """
    test_app = FastAPI()
    test_app.add_middleware(RateLimiter)

    @test_app.get("/test")
    def mock_endpoint():
        return {"message": "Hello World"}

    test_client = TestClient(test_app)

    monkeypatch.setattr(
        "whyhow_api.middleware.get_settings",
        lambda: test_settings,
    )

    # Only the request belongs inside ``pytest.raises``. An ``assert`` in this
    # block would raise AssertionError (an Exception subclass) on a wrong
    # status code and make the test pass for the wrong reason.
    # TODO(review): narrow to the concrete exception type (likely an
    # HTTPException with status 401) once confirmed against the middleware.
    with pytest.raises(Exception):
        test_client.get("/test")
